import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from hypersense.sampler.base_sampler import BaseSampler


class KMeansCentroidSampler(BaseSampler):
    """
    Sample data points closest to each KMeans cluster centroid.
    Useful for selecting a representative and diverse subset.

    The dataset is converted to a 2-D NumPy array (a 1-D dataset is treated
    as a single feature column), optionally standardized, and clustered into
    ``sample_size`` clusters. For every cluster the member point nearest to
    its centroid is selected, and the original dataset items at those
    indices are returned.
    """

    def __init__(self, dataset, sample_size, seed=42, scale=True, **kwargs):
        super().__init__(dataset, sample_size, seed, **kwargs)
        self.scale = scale
        self.sampled_indices = []  # Will store indices after sampling
        self._data_used = None  # Feature matrix (scaled or not) used by the last sample()

    def sample(self):
        """
        Select one representative point per KMeans cluster.

        Returns:
            list: ``sample_size`` distinct items of ``self.dataset``, ordered
            by cluster id.

        Raises:
            ValueError: If ``sample_size`` is not positive or exceeds the
                dataset size.
        """
        data = self._as_feature_matrix(self.dataset)

        # Validate before doing any (potentially expensive) scaling/clustering.
        if self.sample_size <= 0:
            raise ValueError("Sample size must be positive")
        if self.sample_size > len(data):
            raise ValueError("Sample size exceeds dataset size")

        if self.scale:
            data = StandardScaler().fit_transform(data)

        kmeans = KMeans(n_clusters=self.sample_size, random_state=self.seed, n_init='auto')
        kmeans.fit(data)

        sampled_indices = self._closest_to_centers(
            data, kmeans.cluster_centers_, kmeans.labels_
        )

        self.sampled_indices = sampled_indices  # Save for plot()
        self._data_used = data  # Save transformed data for plotting if needed

        return [self.dataset[i] for i in sampled_indices]

    @staticmethod
    def _as_feature_matrix(dataset):
        """Copy *dataset* into an array, turning a 1-D dataset into one feature column."""
        data = np.array(dataset)
        if data.ndim == 1:
            # StandardScaler/KMeans require 2-D input (n_samples, n_features).
            data = data.reshape(-1, 1)
        return data

    @staticmethod
    def _closest_to_centers(data, centers, labels):
        """
        Return, for each center, the index of the nearest point assigned to it.

        KMeans can leave a cluster without members (e.g. when the data has
        fewer distinct points than clusters). Such a cluster falls back to the
        nearest point not already selected, so the result always contains
        ``len(centers)`` distinct indices (requires ``len(centers) <= len(data)``).

        Args:
            data: Feature matrix that was clustered.
            centers: Cluster centroids, one row per cluster.
            labels: Cluster id of each row of ``data``.

        Returns:
            list: Selected row indices, position ``i`` belonging to cluster ``i``.
        """
        chosen = [None] * len(centers)
        taken = np.zeros(len(data), dtype=bool)

        # First pass: clusters with members pick their own closest member.
        # Clusters are disjoint, so these picks are automatically distinct.
        for cluster_id, centroid in enumerate(centers):
            members = np.flatnonzero(labels == cluster_id)
            if members.size == 0:
                continue  # handled in the fallback pass below
            distances = np.linalg.norm(data[members] - centroid, axis=1)
            best = members[np.argmin(distances)]
            chosen[cluster_id] = best
            taken[best] = True

        # Fallback pass: empty clusters take the nearest still-unselected point.
        for cluster_id, centroid in enumerate(centers):
            if chosen[cluster_id] is not None:
                continue
            candidates = np.flatnonzero(~taken)
            distances = np.linalg.norm(data[candidates] - centroid, axis=1)
            best = candidates[np.argmin(distances)]
            chosen[cluster_id] = best
            taken[best] = True

        return chosen

    def plot(self, annotate: bool = True):
        """
        Visualize 2D projection of dataset and sampled points using PCA.

        The projection is computed on the same (possibly standardized) feature
        matrix that ``sample()`` clustered. Single-feature data is drawn with a
        zero second coordinate.

        Args:
            annotate (bool): Whether to label each sampled point with its
                position in ``sampled_indices`` (i.e. its cluster id).

        Raises:
            RuntimeError: If ``sample()`` has not been called yet.
        """
        if not self.sampled_indices or self._data_used is None:
            raise RuntimeError("Please call sample() before plotting.")

        data = self._data_used  # Either original or scaled, used in sampling

        # PCA cannot produce more components than samples or features.
        n_components = min(2, data.shape[0], data.shape[1])
        pca = PCA(n_components=n_components, random_state=self.seed)
        data_2d = pca.fit_transform(data)
        if data_2d.shape[1] < 2:
            # Pad so low-dimensional data can still be shown on a 2-D scatter.
            padding = np.zeros((data_2d.shape[0], 2 - data_2d.shape[1]))
            data_2d = np.hstack([data_2d, padding])

        subset_2d = data_2d[self.sampled_indices]

        plt.figure(figsize=(8, 6))
        plt.scatter(data_2d[:, 0], data_2d[:, 1], label="All Data", alpha=0.3, s=20)
        plt.scatter(subset_2d[:, 0], subset_2d[:, 1], label="Sampled Points", color='red', s=50)

        if annotate:
            for i, (x, y) in enumerate(subset_2d):
                plt.text(x, y, str(i), fontsize=8, ha='center', va='center', color='black')

        plt.title("KMeans Centroid Sampling Visualization")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()
